{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3GvsqHtXQvvX"
   },
   "source": [
    "# Initial Setups\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MTrRCzgmFwYn"
   },
   "source": [
    "## (Google Colab use only)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xC1S4mWFFv5U"
   },
   "outputs": [],
   "source": [
    "# Use Google Colab\n",
    "use_colab = True\n",
    "\n",
    "# Is this notebook running on Colab?\n",
    "# If so, then google.colab package (github.com/googlecolab/colabtools)\n",
    "# should be available in this environment, so we simply try to import it.\n",
    "# Only a failed import means \"not on Colab\": catching ImportError (instead of\n",
    "# a bare except) keeps KeyboardInterrupt and unrelated errors visible.\n",
    "try:\n",
    "    from google.colab import drive\n",
    "    colab_available = True\n",
    "except ImportError:\n",
    "    colab_available = False\n",
    "\n",
    "if use_colab and colab_available:\n",
    "    # Mount Google Drive at /content/drive\n",
    "    drive.mount('/content/drive')\n",
    "\n",
    "    # cd to the appropriate working directory under my Google Drive\n",
    "    %cd 'drive/My Drive/cs696ds_lexalytics/Language Model Finetuning'\n",
    "    \n",
    "    # Install packages specified in requirements\n",
    "    !pip install -r requirements.txt\n",
    "    \n",
    "    # List the directory contents\n",
    "    !ls"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uZ5SZXkbEp34"
   },
   "source": [
    "## Experiment Parameters\n",
    "\n",
    "**NOTE**: The following `experiment_id` MUST BE CHANGED for every new experiment; otherwise the files from other experiments will be overwritten.\n",
    "\n",
    "**NOTE 2**: The values for the variables in the cell below can be overridden by `papermill` at runtime. Variables in other cells cannot be changed in this manner."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "rUqrV6VeEs3Z",
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# We will use the following string ID to identify this particular (training) experiment\n",
    "# in directory paths and other settings\n",
    "experiment_id = 'lm_further_pretraining_bert_amazon_electronics'\n",
    "\n",
    "# Random seed\n",
    "random_seed = 696\n",
    "\n",
    "# Dataset size related\n",
    "total_subset_proportion = 1.0 # Do we want to use the entirety of the training set, or some parts of it?\n",
    "validation_dataset_proportion = 0.1 # Proportion to be reserved for validation (after selecting random subset with total_subset_proportion)\n",
    "\n",
    "# Training hyperparameters\n",
    "num_train_epochs = 20 # Number of epochs\n",
    "per_device_train_batch_size = 16 # training batch size PER COMPUTE DEVICE\n",
    "per_device_eval_batch_size = 16 # evaluation batch size PER COMPUTE DEVICE\n",
    "learning_rate = 1e-5\n",
    "weight_decay = 0.01\n",
    "\n",
    "# Settings for checkpoint resumption\n",
    "# Provide a string of relative path to transformers.Trainer compatible checkpoint.\n",
    "# If None, then the training will start from scratch.\n",
    "checkpoint_path = None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kSItEk35-R8o"
   },
   "source": [
    "## Package Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "o-jSRWQfLL4U"
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import transformers\n",
    "import datasets\n",
    "\n",
    "import utils\n",
    "\n",
    "# Seed the Python, NumPy and PyTorch random number generators with random_seed\n",
    "torch.manual_seed(random_seed)\n",
    "np.random.seed(random_seed)\n",
    "random.seed(random_seed)\n",
    "\n",
    "# Report the interpreter and library versions used for this run\n",
    "print(f\"Python version: {sys.version}\")\n",
    "print(f\"NumPy version: {np.__version__}\")\n",
    "print(f\"PyTorch version: {torch.__version__}\")\n",
    "print(f\"Transformers version: {transformers.__version__}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rkKDoXUp-UIi"
   },
   "source": [
    "## PyTorch GPU settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "je9BT2pQIpUx"
   },
   "outputs": [],
   "source": [
    "# Select the compute device together with the matching data-loading settings\n",
    "if not torch.cuda.is_available():\n",
    "    torch_device = torch.device('cpu')\n",
    "\n",
    "    # No page-locked host memory on the CPU path\n",
    "    use_pin_memory = False\n",
    "\n",
    "    # Number of compute devices to be used for training\n",
    "    training_device_count = 1\n",
    "\n",
    "else:\n",
    "    torch_device = torch.device('cuda')\n",
    "\n",
    "    # Deterministic cuDNN is left off here; set this to True to make your output\n",
    "    # immediately reproducible\n",
    "    # Note: https://pytorch.org/docs/stable/notes/randomness.html\n",
    "    torch.backends.cudnn.deterministic = False\n",
    "\n",
    "    # 'benchmark' mode is left on here; set this to False if you want to measure\n",
    "    # running times more fairly\n",
    "    # Note: https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936\n",
    "    torch.backends.cudnn.benchmark = True\n",
    "\n",
    "    # Faster Host to GPU copies with page-locked memory\n",
    "    use_pin_memory = True\n",
    "\n",
    "    # Number of compute devices to be used for training\n",
    "    training_device_count = torch.cuda.device_count()\n",
    "\n",
    "    # CUDA libraries version information\n",
    "    print(f\"CUDA Version: {torch.version.cuda}\")\n",
    "    print(f\"cuDNN Version: {torch.backends.cudnn.version()}\")\n",
    "    print(f\"CUDA Device Name: {torch.cuda.get_device_name()}\")\n",
    "    print(f\"CUDA Capabilities: {torch.cuda.get_device_capability()}\")\n",
    "    print(f\"Number of CUDA devices: {training_device_count}\")\n",
    "\n",
    "# Blank separator line, then the device that was picked\n",
    "print()\n",
    "print(\"PyTorch device selected:\", torch_device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "H3txs5s7Q1UG"
   },
   "source": [
    "# Further pre-training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gEnUsBDUOLAm"
   },
   "source": [
    "## Load the BERT-base model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "W3TCqS-3OOIj"
   },
   "outputs": [],
   "source": [
    "tokenizer = transformers.AutoTokenizer.from_pretrained(\"bert-base-uncased\", cache_dir='./bert_base_cache')\n",
    "model = transformers.AutoModelForMaskedLM.from_pretrained(\"bert-base-uncased\", cache_dir='./bert_base_cache')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PrLTLsRFRUKK"
   },
   "source": [
    "## Load the Amazon electronics dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ZpKL0urORmkm"
   },
   "outputs": [],
   "source": [
    "amazon = datasets.load_dataset(\n",
    "    './dataset_scripts/amazon_ucsd_reviews',\n",
    "    data_files={\n",
    "        'train': 'dataset_files/amazon_ucsd_reviews/Electronics.json.gz',\n",
    "    },\n",
    "    cache_dir='./dataset_cache')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2m9wZvndVCvy"
   },
   "outputs": [],
   "source": [
    "data_train = amazon['train']\n",
    "print(\"Number of training data (original):\", len(data_train))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "WrB-ZoTz5-zb"
   },
   "outputs": [],
   "source": [
    "# Shuffle with the fixed seed, then keep the leading total_subset_proportion share of the rows\n",
    "data_train_selected = (\n",
    "    data_train\n",
    "    .shuffle(seed=random_seed)\n",
    "    .select(np.arange(int(len(data_train) * total_subset_proportion)))\n",
    ")\n",
    "print(\"Number of training data (subset):\", len(data_train_selected))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "sQ_RaowTYC9X"
   },
   "outputs": [],
   "source": [
    "# Check out what individual data points look like\n",
    "print(data_train_selected[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WqyEwcDGSJjs"
   },
   "source": [
    "### Preprocessing: Encode the text with Tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tTVxGSY4oe7e"
   },
   "outputs": [],
   "source": [
    "train_dataset = data_train_selected.map(\n",
    "    lambda e: tokenizer(e['text'], truncation=True, padding='max_length', max_length=256),\n",
    "    remove_columns=data_train_selected.column_names,\n",
    "    batched=True, num_proc=8)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train-validation split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sizes of the two partitions: the training part is rounded down,\n",
    "# the validation part receives every remaining example\n",
    "new_train_dataset_size = int((1 - validation_dataset_proportion) * len(train_dataset))\n",
    "new_valid_dataset_size = len(train_dataset) - new_train_dataset_size\n",
    "\n",
    "# Leading rows form the training set; the rows after them form the validation set\n",
    "new_train_dataset = train_dataset.select(\n",
    "    indices=np.arange(0, new_train_dataset_size))\n",
    "new_valid_dataset = train_dataset.select(\n",
    "    indices=np.arange(new_train_dataset_size, len(train_dataset)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Training dataset after split:\", len(new_train_dataset))\n",
    "print(\"Validation dataset after split:\", len(new_valid_dataset))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "p1N4PUdHZeIm"
   },
   "source": [
    "## Pre-train further"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PD6lfHh8mURq"
   },
   "source": [
    "### Training settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "RFho5vyJgVaj"
   },
   "outputs": [],
   "source": [
    "# 'smart' masking\n",
    "collator = utils.DataCollatorForSmartMLM(\n",
    "    tokenizer=tokenizer,\n",
    "    mlm_probability=0.15,\n",
    "    prob_replace_with_mask=0.8,\n",
    "    prob_replace_with_random=0.1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rough step count: full batches per epoch (across all devices) times the number of epochs\n",
    "approx_total_training_steps = num_train_epochs * (\n",
    "    len(new_train_dataset) // (training_device_count * per_device_train_batch_size)\n",
    ")\n",
    "\n",
    "print(f\"There will be approximately {approx_total_training_steps:d} training steps.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8ps4XzQxmTgS"
   },
   "outputs": [],
   "source": [
    "# Trainer configuration. Results and logs both live under ./progress/<experiment_id>/\n",
    "training_args = transformers.TrainingArguments(\n",
    "    output_dir=os.path.join('.', 'progress', experiment_id, 'results'), # output directory\n",
    "    overwrite_output_dir=True,\n",
    "    num_train_epochs=num_train_epochs,              # total number of training epochs\n",
    "    per_device_train_batch_size=per_device_train_batch_size,\n",
    "    per_device_eval_batch_size=per_device_eval_batch_size,\n",
    "    evaluation_strategy='epoch',\n",
    "    logging_dir=os.path.join('.', 'progress', experiment_id, 'logs'), # directory for storing logs\n",
    "    logging_first_step=True,\n",
    "    weight_decay=weight_decay,               # strength of weight decay\n",
    "    seed=random_seed,\n",
    "    learning_rate=learning_rate,\n",
    "    # Mixed precision (AMP) needs a CUDA device; this notebook also supports a\n",
    "    # CPU fallback, so only turn fp16 on when CUDA is actually available\n",
    "    fp16=torch.cuda.is_available(),\n",
    "    fp16_backend='amp',\n",
    "    prediction_loss_only=True,\n",
    "    # NOTE(review): best-model selection presumably needs checkpoints saved on the same\n",
    "    # schedule as evaluation - confirm against the installed transformers version\n",
    "    load_best_model_at_end=True,\n",
    "    dataloader_num_workers=training_device_count * 2,\n",
    "    dataloader_pin_memory=use_pin_memory\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "RQsRp9lMZstE"
   },
   "outputs": [],
   "source": [
    "print(training_args.n_gpu)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "1OQJ8IOIYHXb"
   },
   "outputs": [],
   "source": [
    "trainer = transformers.Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    data_collator=collator, # do the masking on the go\n",
    "    train_dataset=new_train_dataset,\n",
    "    eval_dataset=new_valid_dataset,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZTx5562OmXx5"
   },
   "source": [
    "### Training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# If checkpoint_path was given, print it out\n",
    "if checkpoint_path is not None:\n",
    "    print(\"Resuming from\", str(checkpoint_path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-_v3lAPvb9DK"
   },
   "outputs": [],
   "source": [
    "trainer.train(resume_from_checkpoint=checkpoint_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wLj4Ico8vwhO"
   },
   "source": [
    "### Save the model to the local directory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "WLViPwdTvvxP"
   },
   "outputs": [],
   "source": [
    "trainer.save_model(os.path.join('.', 'trained_models', experiment_id))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "SIQn4r1oVJp6"
   },
   "outputs": [],
   "source": [
    "tokenizer.save_pretrained(os.path.join('.', 'trained_models', experiment_id))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RNUCURCduqYa"
   },
   "source": [
    "## LM Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "u1XN8ot3us18"
   },
   "outputs": [],
   "source": [
    "eval_results = trainer.evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zKkBZEpF07Ip"
   },
   "outputs": [],
   "source": [
    "print(eval_results)\n",
    "\n",
    "perplexity = np.exp(eval_results[\"eval_loss\"])\n",
    "\n",
    "print(perplexity)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "I_bAx1rL0KFb"
   },
   "source": [
    "## Playing with my own input sentences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "hmIJszldveH6"
   },
   "outputs": [],
   "source": [
    "# A probe sentence with three masked positions for the model to fill in\n",
    "example = f\"\"\"The {tokenizer.mask_token} of {tokenizer.mask_token} is awful, but its {tokenizer.mask_token} is fantastic.\"\"\"\n",
    "\n",
    "# Encode as a batch of one (special tokens included) and move it to the selected device\n",
    "example_encoded = tokenizer.encode(example, add_special_tokens=True, return_tensors=\"pt\").to(torch_device)\n",
    "\n",
    "# Let's decode this back just to see how they were actually encoded.\n",
    "# The whole id list is converted in one call; the previous loop used a variable\n",
    "# named `id`, which shadowed the builtin for every later cell of the notebook.\n",
    "example_tokens = tokenizer.convert_ids_to_tokens(example_encoded[0].tolist())\n",
    "\n",
    "print(example_tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "CbHup4-rxBSD"
   },
   "outputs": [],
   "source": [
    "# Inference only: put the model in evaluation mode (no dropout) explicitly rather than\n",
    "# relying on an earlier cell, and skip autograd bookkeeping for the forward pass\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    example_prediction = model(example_encoded)\n",
    "\n",
    "# First output holds the per-position vocabulary scores; take the best id at every\n",
    "# position of the single sentence in the batch\n",
    "example_prediction_argmax = torch.argmax(example_prediction[0], dim=-1)[0]\n",
    "\n",
    "print(example_prediction_argmax)\n",
    "\n",
    "print(tokenizer.decode(example_prediction_argmax))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "EM3YetZAm3L-"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "lm_further_pretraining_bert_amazon_electronics.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
